{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "hypernerf_optim_latent",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "-QWb-snnOf1I"
      },
      "source": [
        "def apply_model(rng, state, batch):\n",
        "  fine_key, coarse_key = random.split(rng, 2)\n",
        "  model_out = model.apply(\n",
        "      {'params': state.optimizer.target['model']}, \n",
        "      batch,\n",
        "      extra_params=state.extra_params,\n",
        "      metadata_encoded=True,\n",
        "      rngs={'fine': fine_key, 'coarse': coarse_key})\n",
        "  return model_out\n",
        "\n",
        "\n",
        "def loss_fn(rng, state, target, batch):\n",
        "  batch['metadata'] = jax.tree_map(lambda x: x.reshape((1, -1)), \n",
        "                                   target['metadata'])\n",
        "  model_out = apply_model(rng, state, batch)['fine']\n",
        "  # loss = ((model_out['rgb'] - batch['rgb']) ** 2).mean(axis=-1)\n",
        "  loss = jnp.abs(model_out['rgb'] - batch['rgb']).mean(axis=-1)\n",
        "  return loss.mean()\n",
        "\n",
        "\n",
        "def optim_step(rng, state, optimizer, batch):\n",
        "  rng, key = random.split(rng, 2)\n",
        "  grad_fn = jax.value_and_grad(loss_fn, argnums=2)\n",
        "  loss, grad = grad_fn(key, state, optimizer.target, batch)\n",
        "  grad = jax.lax.pmean(grad, axis_name='batch')\n",
        "  loss = jax.lax.pmean(loss, axis_name='batch')\n",
        "\n",
        "  optimizer = optimizer.apply_gradient(grad)\n",
        "\n",
        "  return rng, loss, optimizer\n",
        "\n",
        "\n",
        "p_optim_step = jax.pmap(optim_step, axis_name='batch')\n",
        "\n",
        "key = random.PRNGKey(0)\n",
        "key = key + jax.process_index()\n",
        "keys = random.split(key, jax.local_device_count())\n",
        "\n",
        "optimizer_def = optim.Adam(0.1)\n",
        "init_metadata = evaluation.encode_metadata(\n",
        "  model, \n",
        "  jax_utils.unreplicate(state.optimizer.target['model']), \n",
        "  jax.tree_map(lambda x: x[0, 0], data['metadata']))\n",
        "# init_metadata = jax.tree_map(lambda x: x[0], init_metadata)\n",
        "# Initialize to zero.\n",
        "init_metadata = jax.tree_map(lambda x: jnp.zeros_like(x), init_metadata)\n",
        "optimizer = optimizer_def.create({'metadata': init_metadata})\n",
        "optimizer = jax_utils.replicate(optimizer, jax.local_devices())\n",
        "devices = jax.local_devices()\n",
        "batch_size = 1024\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gzqX4gTWOf3b"
      },
      "source": [
        "metadata_progression = []\n",
        "\n",
        "for i in range(25):\n",
        "  batch_inds = random.choice(keys[0], np.arange(train_data['rgb'].shape[0]), replace=False, shape=(batch_size,))\n",
        "  batch = jax.tree_map(lambda x: x[batch_inds, ...], train_data)\n",
        "  batch = datasets.prepare_data(batch)\n",
        "  keys, loss, optimizer = p_optim_step(keys, state, optimizer, batch)\n",
        "  loss = jax_utils.unreplicate(loss)\n",
        "  metadata_progression.append(jax.tree_map(lambda x: np.array(x), jax_utils.unreplicate(optimizer.target['metadata'])))\n",
        "  print(f'train_loss = {loss.item():.04f}')\n",
        "  del batch"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-s9pCeiqZhKi"
      },
      "source": [
        "frames = []\n",
        "for metadata in metadata_progression:\n",
        "# metadata = jax_utils.unreplicate(optimizer.target['metadata'])\n",
        "  camera = datasource.load_camera(target_id).scale(1.0)\n",
        "  batch = make_batch(camera, None, metadata['encoded_warp'], metadata['encoded_hyper'])\n",
        "  render = render_fn(state, batch, rng=rng)\n",
        "  pred_rgb = np.array(render['rgb'])\n",
        "  pred_depth_med = np.array(render['med_depth'])\n",
        "  pred_depth_viz = viz.colorize(1.0 / pred_depth_med.squeeze())\n",
        "  media.show_images([pred_rgb, pred_depth_viz])\n",
        "  frames.append({ \n",
        "      'rgb': pred_rgb,\n",
        "      'depth': pred_depth_med,\n",
        "  })\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tu8cC9gKe2BS"
      },
      "source": [
        "media.show_image(data['rgb'])\n",
        "media.show_videos([\n",
        "    [d['rgb'] for d in frames],\n",
        "    [viz.colorize(1/d['depth'].squeeze(), cmin=1.5, cmax=2.9) for d in frames],\n",
        "], fps=10)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EUpAEa4boPbw"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}